#!/usr/bin/env python3
"""
T3 helper — sanity-check stellar-mass bins (KiDS DR4 Bright + LePhare).

Fixes:
- FITS endianness (astropy Table.read → to_pandas, then native-endian cast)
- Robust mass semantics (auto-detect log10 vs linear; NEVER double-log)
- Drops LePhare sentinel values (e.g., -99) and implausible logs (<6 or >13)
- Stable ID joins via string keys

Outputs:
- Quantiles for Mstar_log10 (cleaned)
- Counts per frozen T3 mass bin [10.2,10.5), [10.5,10.8), [10.8,11.1)
"""

import numpy as np
import pandas as pd
from astropy.io import fits
from astropy.table import Table

# Input catalogues as relative paths; main() loads both through fits_to_df().
BRIGHT = "data/KiDS_DR4_brightsample.fits"
LEPH   = "data/KiDS_DR4_brightsample_LePhare.fits"
# main() hands these edges to pd.cut(..., right=False), so each bin is
# half-open: [10.2, 10.5), [10.5, 10.8), [10.8, 11.1).
MS_EDGES = [10.2, 10.5, 10.8, 11.1]  # log10(M/Msun) bins (frozen)


def _first_table_hdu_index(hdul):
    for i, hdu in enumerate(hdul):
        if isinstance(hdu, fits.BinTableHDU):
            return i
    return None


def fits_to_df(path: str) -> pd.DataFrame:
    """Read the first binary table of a FITS file into a native-endian DataFrame.

    Args:
        path: Location of the FITS file.

    Returns:
        The table converted with ``Table.to_pandas()``; every numeric column
        whose NumPy dtype is not in native byte order is converted to the
        native order (values are preserved).

    Raises:
        SystemExit: If the file contains no ``BinTableHDU``.
    """
    # First pass only locates the table extension; the handle is closed
    # again before the table itself is read.
    with fits.open(path, memmap=True) as hdul:
        idx = _first_table_hdu_index(hdul)
        if idx is None:
            raise SystemExit(f"No BinTableHDU found in {path}")
    t = Table.read(path, format="fits", hdu=idx, memmap=True)
    df = t.to_pandas()
    for c in df.select_dtypes(include=[np.number]).columns:
        arr = df[c].to_numpy()
        # `isnative` catches any non-native order, whichever endianness the
        # host uses, instead of only matching an explicit ">" marker.
        if not arr.dtype.isnative:
            # ndarray.newbyteorder() was removed in NumPy 2.0; astype() to the
            # native-order dtype swaps the bytes and relabels in one step.
            df[c] = arr.astype(arr.dtype.newbyteorder("="))
    return df


def pick_mass_series(
    df_leph: pd.DataFrame,
    *,
    candidates: tuple[str, ...] = (
        "MASS_MED",
        "MASS_BEST",
        "LOGMASS",
        "LOGMSTAR",
        "MASS_TOT",
        "MSTAR",
    ),
    min_finite: int = 1000,
) -> tuple[str, pd.Series]:
    """Choose the first candidate mass column with enough finite values.

    Args:
        df_leph: Table to search for a mass column.
        candidates: Column names to try, in priority order.
        min_finite: A column is accepted only if it holds strictly more than
            this many finite (non-NaN, non-infinite) numeric values.

    Returns:
        ``(column_name, values)`` where *values* is the column coerced to
        numeric (unparseable entries become NaN). Infinite entries are left
        in place in the returned series; they are only excluded from the count.

    Raises:
        SystemExit: If no candidate column qualifies.
    """
    for col in candidates:
        if col not in df_leph.columns:
            continue
        s = pd.to_numeric(df_leph[col], errors="coerce")
        # Count finite values, not merely non-NaN ones: a column filled with
        # +/-inf would otherwise be selected and later cleaned to nothing.
        n_finite = s.replace([np.inf, -np.inf], np.nan).notna().sum()
        if n_finite > min_finite:
            return col, s
    raise SystemExit("No usable mass column found in LePhare file.")


def as_log10(series: pd.Series) -> pd.Series:
    """Return log10(M/Msun), auto-detecting whether the input is already log10.

    Detection uses the median of the strictly positive values only, so that
    negative sentinels (e.g. -99) cannot drag the median out of range and
    flip the decision. A median within [1e7, 1e13] is treated as linear
    solar masses and logged; anything else is taken as already log10.

    Args:
        series: Raw mass values; non-numeric entries are coerced to NaN.

    Returns:
        A series on the same index in which non-finite values and values
        outside [6, 13] are NaN. If the input has no finite values it is
        returned as the coerced (all-NaN) series.
    """
    s = pd.to_numeric(series, errors="coerce").replace([np.inf, -np.inf], np.nan)
    if s.notna().sum() == 0:
        return s

    # Both plausible log10 masses (6..13) and linear masses are positive, so
    # non-positive entries carry no information about the column's semantics.
    positive = s.where(s > 0)
    if positive.notna().sum() == 0:
        # Only sentinels / non-positive values: nothing survives cleaning.
        return positive

    q50 = float(positive.median())
    if 1e7 <= q50 <= 1e13:
        # Linear Msun. Take the log of positive entries only; zeros and
        # negatives stay NaN rather than triggering log10 runtime warnings.
        logm = np.log10(positive)
    else:
        # Already log10 (typical median 6..13) or ambiguous: never double-log.
        logm = s

    # Clean: drop sentinels/implausible ranges (e.g., -99) before reporting/binning.
    return logm.mask(~np.isfinite(logm) | (logm < 6.0) | (logm > 13.0))


def main():
    """Print mass quantiles and per-bin lens counts for the joined catalogues.

    Loads the bright sample and the LePhare table, derives cleaned log10
    stellar masses from the LePhare table, inner-joins the two tables on a
    whitespace-stripped string form of ``ID``, and prints the 5/25/50/75/95%
    quantiles together with the counts in the ``MS_EDGES`` bins.

    Raises:
        SystemExit: If the bright sample lacks a required column, if the
            LePhare table lacks ``ID``, or if no usable mass column exists.
    """
    # Load tables (endianness-safe)
    B = fits_to_df(BRIGHT)
    L = fits_to_df(LEPH)

    need = ["ID", "RAJ2000", "DECJ2000", "zphot_ANNz2", "MASK"]
    missing = [c for c in need if c not in B.columns]
    if missing:
        raise SystemExit(f"Bright sample missing columns: {missing}")
    # The join key comes from the LePhare table as well; report its absence
    # in the same way instead of failing later with a bare KeyError.
    if "ID" not in L.columns:
        raise SystemExit("LePhare file missing columns: ['ID']")

    _, mass_raw = pick_mass_series(L)
    Mlog = as_log10(mass_raw)

    # Join by ID via normalized string key
    B2 = B[need].copy()
    B2["ID_key"] = B2["ID"].astype(str).str.strip()
    L2 = L[["ID"]].copy()
    L2["ID_key"] = L2["ID"].astype(str).str.strip()
    # Mlog shares L's index, so this assignment lines up row by row.
    L2["Mstar_log10"] = Mlog

    D = B2.merge(L2[["ID_key", "Mstar_log10"]], on="ID_key", how="inner")
    D = D[D["Mstar_log10"].notna()].copy()

    # Report quantiles and per-bin counts
    q = np.nanquantile(D["Mstar_log10"], [0.05, 0.25, 0.5, 0.75, 0.95])
    print(f"[INFO] Lenses after join & cleaned masses: {len(D):,}")
    print(
        "[INFO] Mstar_log10 quantiles [5/25/50/75/95%]: "
        f"[{q[0]:.3f} {q[1]:.3f} {q[2]:.3f} {q[3]:.3f} {q[4]:.3f}]"
    )

    # right=False makes every bin closed on the left and open on the right.
    D["Mbin"] = pd.cut(D["Mstar_log10"], MS_EDGES, right=False)
    counts = D["Mbin"].value_counts().sort_index()
    print("\n[INFO] Counts per M* bin (10.2–10.5, 10.5–10.8, 10.8–11.1):")
    print(counts.rename("count"))
    print(f"\n[INFO] Total usable-in-bins: {int(counts.sum()):,}")

# Run the report only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
